{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

{# ------------------------------------------------------------------------- #}
{# Utilities for Rust code generation templates.                              #}
{# ------------------------------------------------------------------------- #}

{# Map the configured scalar type onto the matching Rust primitive.
 #
 # Emits `f32` when spec.config.scalar_type.name is "FLOAT" and `f64` when it is
 # "DOUBLE". Any other name is reported through raise().
 #
 # Args:
 #     spec (Codegen):
 #}
{%- macro format_scalar(spec) -%}
    {%- set scalar_name = spec.config.scalar_type.name -%}
    {%- if scalar_name == "DOUBLE" -%}
        f64
    {%- elif scalar_name == "FLOAT" -%}
        f32
    {%- else -%}
        {{ raise("Unsupported scalar type: {}".format(spec.config.scalar_type)) }}
    {%- endif -%}
{%- endmacro -%}

{# Emit the nalgebra statically sized vector type for a matrix type.
 #
 # The element count is T.SHAPE[0] * T.SHAPE[1].
 #
 # Args:
 #     T (type): Type exposing a two-element SHAPE
 #     name (str): Not used by this macro
 #     spec (Codegen):
 #}
{%- macro format_vector(T, name, spec) -%}
    {% set num_elements = T.SHAPE[0] * T.SHAPE[1] %}
    nalgebra::SVector::<{{ format_scalar(spec) }}, {{ num_elements }}>
{%- endmacro -%}

{# Emit the nalgebra statically sized matrix type for a matrix type.
 #
 # Rows come from T.SHAPE[0] and columns from T.SHAPE[1].
 #
 # Args:
 #     T (type): Type exposing a two-element SHAPE
 #     name (str): Not used by this macro
 #     spec (Codegen):
 #}
{%- macro format_matrix(T, name, spec) -%}
    {%- set n_rows = T.SHAPE[0] -%}
    {%- set n_cols = T.SHAPE[1] -%}
    nalgebra::SMatrix::<{{ format_scalar(spec) }}, {{ n_rows }}, {{ n_cols }}>
{%- endmacro -%}

{# Emit the Rust type name used for a value of the given type.
 #
 # - Symbolic values and `float` map to the scalar type (see format_scalar).
 # - Matrix subclasses map to an SVector when either dimension of SHAPE is 1, and to an
 #   SMatrix otherwise.
 # - Any other type is reported through raise().
 #
 # The emitted text is surrounded by the whitespace of the branches below; callers that
 # care pipe the result through filters such as `indent` or rely on Rust ignoring it.
 #
 # Args:
 #     name (str): Name of the value; forwarded to the helpers and used in the error message
 #     type (type or Element): Resolved with typing_util.get_type
 #     spec (Codegen):
 #}
{%- macro format_typename(name, type, spec) -%}
    {% set T = typing_util.get_type(type) %}
    {% if is_symbolic(type) or T.__name__ == "float" %}
        {{ format_scalar(spec) }}
    {% else %}
        {% if issubclass(T, Matrix) %}
            {% if T.SHAPE[0] == 1 or T.SHAPE[1] == 1 %}
                {{ format_vector(T, name, spec) }}
            {% else %}
                {{ format_matrix(T, name, spec) }}
            {% endif %}
        {% else %}
            {{ raise("Unsupported type {} for value {}".format(T, name)) }}
        {% endif %}
    {% endif %}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit the Rust return type of the generated function.
 #
 # The type is looked up in spec.outputs under the key spec.return_key:
 #
 # - If spec.return_key is none, the unit type `()` is emitted.
 # - Symbolic values and `float` map to the scalar type (see format_scalar).
 # - Matrix subclasses are always emitted through format_vector, i.e. as an SVector with
 #   SHAPE[0] * SHAPE[1] elements, regardless of the matrix shape.
 # - Any other type is reported through raise().
 #
 # Args:
 #     spec (Codegen):
 #}
{%- macro get_return_type(spec) -%}
    {%- if spec.return_key is not none -%}
        {% set type = spec.outputs[spec.return_key] %}
        {% set T = typing_util.get_type(type) %}
        {% if is_symbolic(type) or T.__name__ == "float" %}
            {{ format_scalar(spec) }}
        {% elif issubclass(T, Matrix) %}
            {{ format_vector(T, spec.return_key, spec) }}
        {% else %}
            {{ raise("Unsupported return type: {}".format(T)) }}
        {% endif %}
    {%- else -%}
        ()
    {%- endif -%}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

 {# Format a function docstring as Rust `///` doc comment lines.
 #
 # The docstring is split on newlines and each line is emitted as `///`, a space, and the
 # line text, with trailing whitespace removed (so an empty line becomes a bare `///`).
 # Nothing is emitted when the docstring is empty or undefined.
 #
 # Args:
 #     docstring (str):
 #}
{% macro print_docstring(docstring) %}
{%- if docstring %}


{%- for line in docstring.split('\n') %}
///{{ ' {}'.format(line).rstrip() }}
{% endfor -%}
{%- endif -%}
{% endmacro %}

{# ------------------------------------------------------------------------- #}

{# Format a single function input argument.
 #
 # - A type named "Symbol" is passed by value as `name: SCALAR` (see format_scalar).
 # - A Matrix subclass is passed by shared reference as `name: &TYPE`
 #   (see format_typename).
 # - Any other type is reported through raise().
 #
 # Args:
 #     name (str): Name of the argument
 #     T_or_value (type or Element): Resolved with typing_util.get_type
 #     spec (Codegen):
 #}
{%- macro format_input_arg(name, T_or_value, spec) %}
    {%- set T = typing_util.get_type(T_or_value) -%}
    {% if T.__name__ == "Symbol" %}
     {{ name }}: {{ format_scalar(spec) }}
    {% elif issubclass(T, Matrix) %}
        {{ name }}: &{{ format_typename(name, T_or_value, spec) }}
    {% else -%}
        {{ raise('Unsupported type {} for input "{}"'.format(T, name)) }}
    {%- endif %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Format a single optional output argument.
 #
 # Emits `name: Option<&mut TYPE>`, where TYPE comes from format_typename.
 #
 # Args:
 #     name (str): Name of the output argument
 #     type (type or Element): Type of the output, forwarded to format_typename
 #     spec (Codegen):
 #}
{%- macro format_output_arg(name, type, spec) %}
    {{ name }}: Option<&mut {{- format_typename(name, type, spec) | indent(width=0) -}}>
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate the comma-separated argument list of the generated function.
 #
 # Every entry of spec.inputs is emitted through format_input_arg, followed by every entry
 # of spec.outputs whose key differs from spec.return_key, emitted through
 # format_output_arg.
 #
 # Separator handling:
 #   - An input is followed by ", " unless it is the last input and no output argument
 #     follows it, i.e. spec.outputs is empty, or it has exactly one entry while
 #     spec.return_key is set.
 #   - An emitted output argument is followed by ", " unless it is the last entry of
 #     spec.outputs, or it is the second-to-last entry and the last entry is the return key.
 #
 # Args:
 #     spec (Codegen):
 #}
{%- macro input_args_declaration(spec) %}
    {%- for name, type in spec.inputs.items() -%}
        {{ format_input_arg(name, type, spec) | indent(width=0) }}
        {%- if not loop.last
            or spec.outputs.items() | length > 1
            or (spec.outputs.items() | length == 1 and spec.return_key is none) -%}
        , {% endif -%}
    {%- endfor -%}
    {%- for name, type in spec.outputs.items() -%}
        {%- if name != spec.return_key -%}
            {{- format_output_arg(name, type, spec) | indent(width=0) -}}
            {%- if not loop.last -%}
                {%- if not (loop.revindex0 == 1 and loop.nextitem[0] == spec.return_key) -%}
                , {% endif -%}
            {% endif -%}
        {%- endif -%}
    {%- endfor -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate the function signature, without the body or its braces.
 #
 # Emits `pub fn NAME(ARGS) -> RETURN_TYPE`, where NAME is spec.name, ARGS comes from
 # input_args_declaration and RETURN_TYPE comes from get_return_type.
 #
 # Args:
 #     spec (Codegen):
 #}
{%- macro function_declaration(spec) -%}
{% set name = spec.name %}
    pub fn {{ name }}({{- input_args_declaration(spec) | indent(width=4) -}}) -> {{ get_return_type(spec) | indent(width=8)}}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Helper to generate the assignments that fill out one output argument.
 #
 # - For a Matrix subclass, one `lhs = rhs;` statement is emitted per term.
 # - For a symbolic value or `float`, `*name = rhs;` is emitted using the right-hand side
 #   of the first term only.
 # - Any other type is reported through raise().
 #
 # Args:
 #     name (str): Name of the output object
 #     type (type): Type of the output object
 #     terms (List[Tuple[str, str]]): (lhs, rhs) pairs of output terms for this object
 #     spec (Codegen): Not used by this macro
 #}
{% macro format_output(name, type, terms, spec) -%}
{% set T = typing_util.get_type(type) %}
{% if issubclass(T, Matrix) %}
    {% for lhs, rhs in terms %}
        {{ lhs }} = {{ rhs }};
    {% endfor %}
{% elif is_symbolic(type) or T.__name__ == "float" %}
    *{{ name }} = {{ terms[0][1] }};
{% else %}
    {{ raise('Unsupported return type for Rust: name="{}", type=`{}`'.format(name, T)) }}
{% endif %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Helper to generate the code that fills out every optional output argument.
 #
 # For each entry of spec.print_code_results.dense_terms whose name differs from
 # spec.return_key, an `if let Some(name) = name` block is emitted so the output is only
 # computed when the caller supplied it. Matrix outputs are first reset with `zeros()`
 # and then filled by format_output.
 #
 # Any entry in spec.print_code_results.sparse_terms is reported through raise().
 #
 # Args:
 #     spec (Codegen):
 #}
{% macro format_outputs(spec) -%}
    {% for name, type, terms in spec.print_code_results.dense_terms %}
        {% if name != spec.return_key %}
    if let Some( {{ name }} ) = {{ name }} {

        {% set T = typing_util.get_type(type) %}
        {% if issubclass(T, Matrix) %}
        *{{ name }} = {{ format_typename(name, type, spec) }}::zeros();
        {% endif %}
        {{ format_output(name, type, terms, spec) | indent(width=8) | trim }}
    }
        {% endif %}

    {% endfor %}
    {% for name, type, terms in spec.print_code_results.sparse_terms %}
        {{ raise("Cannot return sparse output {}; Rust does not support sparse matrices".format(name)) }}
    {% endfor %}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate inner code for computing the given expression.
 #
 # The emitted function body consists of, in order:
 #   1. A comment with the total operation count.
 #   2. A comment block listing spec.unused_arguments, if there are any.
 #   3. One `let lhs: SCALAR = rhs;` binding per intermediate term.
 #   4. The optional output assignments (see format_outputs).
 #   5. The value for spec.return_key, emitted as the tail expression of the function.
 #      It must not end in `;`, otherwise the function would evaluate to `()` instead of
 #      the type declared by get_return_type. A scalar return emits the right-hand side of
 #      its first term; a Matrix return builds the vector type from format_vector with
 #      `::new(...)` over the right-hand sides of all terms.
 #
 # Args:
 #     spec (Codegen):
 #}
{% macro expr_code(spec) -%}
    // Total ops: {{ spec.total_ops() }}

    {% if spec.unused_arguments %}
    // Unused inputs
    {% for arg in spec.unused_arguments %}
    // {{ arg }};
    {% endfor %}

    {% endif %}
    // Intermediate terms ({{ spec.print_code_results.intermediate_terms | length }})
    {% for lhs, rhs in spec.print_code_results.intermediate_terms %}
    let {{ lhs }}: {{ format_scalar(spec) }} = {{ rhs }};
    {% endfor %}

    // Output terms ({{ spec.outputs.items() | length }})
    {{ format_outputs(spec) | trim }}
    {# Emit the return_key value, if present, as the tail expression (no trailing `;`) #}
    {% for name, type, terms in spec.print_code_results.dense_terms %}
        {% set T_return = typing_util.get_type(type) %}
        {% if name == spec.return_key and T_return.__name__ != 'NoneType' %}
            {% if is_symbolic(type) or T_return.__name__ == "float" %}
            {{ terms[0][1] }}
            {% elif issubclass(T_return, Matrix) %}
                {{ format_vector(T_return, name, spec) }}::new(
                {% for lhs, rhs in terms %}
        {{ rhs }}{% if not loop.last %},{% endif %}
                {% endfor %}
    )
            {% else %}
            {{ raise("Cannot return an output of type {}: The Rust backend only supports returning scalars, vectors, and matrices.".format(T_return)) }}
            {% endif %}
        {% endif %}
    {% endfor %}
{%- endmacro -%}
